{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64203dd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e32a6fbd",
   "metadata": {},
   "source": [
    "# Notebook distributed training\n",
    "> Using `Accelerate` to launch a training script from your notebook"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "---\n",
    "skip_exec: true\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2322476f",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "In this tutorial we will see how to use [Accelerate](https://github.com/huggingface/accelerate) to launch a training function on a distributed system, from inside your **notebook**!\n",
    "\n",
    "To keep things simple, this example trains on the PETS dataset, showing that all it takes is 3 new lines of code to be on your way!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef9bb971",
   "metadata": {},
   "source": [
    "## Setting up imports and building the DataLoaders\n",
    "\n",
    "First, make sure that Accelerate is installed on your system by running:\n",
    "```bash\n",
    "pip install accelerate -U\n",
    "```\n",
    "\n",
    "In your code, along with the normal `from fastai.module.all import *` imports, the lines marked with `+` need to be added:\n",
    "```diff\n",
    "+ from fastai.distributed import *\n",
    "from fastai.vision.all import *\n",
    "from fastai.vision.models.xresnet import *\n",
    "\n",
    "+ from accelerate import notebook_launcher\n",
    "+ from accelerate.utils import write_basic_config\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6dde5b3",
   "metadata": {},
   "source": [
    "The `fastai.distributed` import brings in the `Learner.distrib_ctx` context manager. The `accelerate` import brings in Accelerate's [notebook_launcher](https://huggingface.co/docs/accelerate/launcher), the key function we will call to run what we want."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cfcb5de",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastai.vision.all import *\n",
    "from fastai.distributed import *\n",
    "from fastai.vision.models.xresnet import *\n",
    "\n",
    "from accelerate import notebook_launcher\n",
    "from accelerate.utils import write_basic_config"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c11155c3-8054-4168-a4e9-a28cc70e8893",
   "metadata": {},
   "source": [
    "We need to set up `Accelerate` to use all of our GPUs. We can do so quickly with `write_basic_config()`:\n",
    "\n",
    ":::{.callout-note}\n",
    "\n",
    "Since this checks `torch.cuda.device_count`, you will need to restart your notebook and skip calling this again to continue. It only needs to be run once! If you choose not to use it, run `accelerate config` from the terminal instead and set `mixed_precision` to `no`.\n",
    "\n",
    ":::"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72a1427a-106e-4cc4-9d5e-123b1bb3a082",
   "metadata": {},
   "outputs": [],
   "source": [
    "#from accelerate.utils import write_basic_config\n",
    "#write_basic_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8003b603",
   "metadata": {},
   "source": [
    "Next, let's download some data to train on. You don't need to worry about using `rank0_first` here, because inside a Jupyter Notebook this cell runs on a single process as usual:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "481ef1c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.PETS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6250dcc",
   "metadata": {},
   "source": [
    "We wrap the creation of the `DataLoaders`, our `vision_learner`, and the call to `fine_tune` inside of a `train` function.\n",
    "\n",
    ":::{.callout-note}\n",
    "\n",
    "It is important to **not** build the `DataLoaders` outside of the function, as absolutely *nothing* can be loaded onto CUDA beforehand.\n",
    "\n",
    ":::"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "055fc7cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_y(o): return o[0].isupper()\n",
    "def train(path):\n",
    "    dls = ImageDataLoaders.from_name_func(\n",
    "        path, get_image_files(path), valid_pct=0.2,\n",
    "        label_func=get_y, item_tfms=Resize(224))\n",
    "    learn = vision_learner(dls, resnet34, metrics=error_rate).to_fp16()\n",
    "    learn.fine_tune(1)"
   ]
  },
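  {
   "cell_type": "markdown",
   "id": "7d1e5a90",
   "metadata": {},
   "source": [
    "As a quick sanity check of the labeling rule, here is a minimal sketch that runs `get_y` on a few file names. It assumes the PETS naming convention, where cat breeds are capitalized and dog breeds are lowercase; the file names below are illustrative examples, not output from this notebook:\n",
    "```python\n",
    "def get_y(o): return o[0].isupper()\n",
    "\n",
    "# Illustrative file names (assumed): capitalized breed = cat, lowercase breed = dog\n",
    "names = ['Abyssinian_1.jpg', 'beagle_1.jpg', 'Bengal_1.jpg', 'pug_1.jpg']\n",
    "labels = {n: get_y(n) for n in names}\n",
    "# True marks a cat, False marks a dog\n",
    "print(labels)\n",
    "```"
   ]
  },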
  {
   "cell_type": "markdown",
   "id": "ae6eece1",
   "metadata": {},
   "source": [
    "The last change needed in the `train` function is to wrap the call to `fine_tune` in our context manager, with `in_notebook` set to `True`:\n",
    "\n",
    ":::{.callout-note}\n",
    "\n",
    "For this example, `sync_bn` is disabled for compatibility with `torchvision`'s resnet34.\n",
    "\n",
    ":::"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf0529db",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(path):\n",
    "    dls = ImageDataLoaders.from_name_func(\n",
    "        path, get_image_files(path), valid_pct=0.2,\n",
    "        label_func=get_y, item_tfms=Resize(224))\n",
    "    learn = vision_learner(dls, resnet34, metrics=error_rate).to_fp16()\n",
    "    with learn.distrib_ctx(sync_bn=False, in_notebook=True):\n",
    "        learn.fine_tune(1)\n",
    "    learn.export(\"pets\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e55fb03",
   "metadata": {},
   "source": [
    "Finally, just call `notebook_launcher`, passing in the training function, any arguments as a tuple, and the number of GPUs (processes) to use:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4af598f8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Launching training on 2 GPUs.\n",
      "Training Learner...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>error_rate</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.342019</td>\n",
       "      <td>0.228441</td>\n",
       "      <td>0.105041</td>\n",
       "      <td>00:54</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>error_rate</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.197188</td>\n",
       "      <td>0.141764</td>\n",
       "      <td>0.062246</td>\n",
       "      <td>00:56</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "notebook_launcher(train, (path,), num_processes=2)"
   ]
  },
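  {
   "cell_type": "markdown",
   "id": "3b8c2e4f",
   "metadata": {},
   "source": [
    "Note the trailing comma in `(path,)`: the arguments must be a tuple because they are unpacked into the training function. The sketch below illustrates this with `toy_launcher`, a hypothetical stand-in written for this example (it is not Accelerate's implementation and starts no processes); `show` and the `'data/pets'` string are likewise placeholders:\n",
    "```python\n",
    "def toy_launcher(function, args=()):\n",
    "    # Hypothetical stand-in: unpack the tuple into the function's arguments\n",
    "    return function(*args)\n",
    "\n",
    "def show(path): return f'training on {path}'\n",
    "\n",
    "print(toy_launcher(show, ('data/pets',)))  # one-element tuple: path arrives whole\n",
    "try:\n",
    "    toy_launcher(show, ('data/pets'))      # no comma: this is just a string\n",
    "except TypeError as e:\n",
    "    print('TypeError:', e)                 # each character became a separate argument\n",
    "```"
   ]
  },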
  {
   "cell_type": "markdown",
   "id": "a6b0d4f3",
   "metadata": {},
   "source": [
    "Afterwards, we can load our exported `Learner`, save it, or do anything else we want in our Jupyter Notebook, outside of a distributed process:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30eed343",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "('False', TensorBase(0), TensorBase([0.9718, 0.0282]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "imgs = get_image_files(path)\n",
    "learn = load_learner(path/'pets')\n",
    "learn.predict(imgs[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
